import numpy as np
import matplotlib.pyplot as plt
def data_to_counts(data, num_outcomes=10):
    ''' Convert the data to an array of counts of the numbers
        of occurrences of each outcome in each row.

        data is a 2-D sequence of non-negative integer outcomes,
        one row per unit. num_outcomes is the minimum number of
        outcome categories (columns) in the result; it defaults
        to 10, the value that was previously hard-coded.

        Returns an integer array with one row per row of data,
        where column k holds how often outcome k occurs in that
        row. Outcomes are expected to be smaller than
        num_outcomes; a larger value makes np.bincount return a
        longer vector for that row.
    '''
    counts = [np.bincount(row, minlength=num_outcomes) for row in data]
    if not counts:
        # Keep the result two-dimensional for empty input so that
        # later reductions along axis 1 still work.
        return np.zeros((0, num_outcomes), dtype=np.intp)
    return np.asarray(counts)
def draw_assignments(counts, ps, qs):
    ''' Compute the conditional distributions of the assignments
        of each row and then draw from these distributions.

        counts is a (rows, outcomes) array of per-row outcome
        counts; ps and qs are the outcome probabilities of the
        two types.

        Returns a boolean array with one entry per row: True
        means the row was assigned to the type with rates ps,
        False to the type with rates qs.
    '''
    counts = np.asarray(counts)

    def log_likelihoods(rates):
        # Work in log space: a raw product of powers underflows to
        # zero for rows with many observations, which would turn the
        # assignment probability into 0 / 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            # Outcomes that were never observed contribute nothing,
            # even when their rate is exactly zero (0 ** 0 == 1).
            terms = np.where(counts > 0, counts * np.log(rates), 0.0)
        return terms.sum(axis=1)

    p_log_likelihoods = log_likelihoods(ps)
    q_log_likelihoods = log_likelihoods(qs)
    # P(row has type p) = Lp / (Lp + Lq), evaluated stably via logaddexp.
    with np.errstate(invalid='ignore'):
        assignments_probs = np.exp(
            p_log_likelihoods
            - np.logaddexp(p_log_likelihoods, q_log_likelihoods))
    assignments = np.random.uniform(size=len(assignments_probs)) < assignments_probs
    return np.asarray(assignments)
def draw_probabilities(counts, assignments, prior=1.0):
    ''' Compute the total counts in all rows assigned to each
        type and draw new sets of rates for each type given
        these counts.

        counts is a (rows, outcomes) array of per-row outcome
        counts and assignments holds one truthy/falsy flag per
        row (True selects the first type). prior is the Dirichlet
        pseudo-count added to every outcome total; the default of
        1.0 is a uniform prior and keeps the draw well defined
        when a type has no rows or an outcome was never observed.
        Pass prior=0 to use the raw totals, in which case
        np.random.dirichlet raises ValueError for zero totals.

        Returns the tuple (ps, qs) of freshly drawn rate vectors.
    '''
    counts = np.asarray(counts)
    is_p = np.asarray(assignments, dtype=bool)
    # Conjugate update: the posterior is Dirichlet(prior + totals).
    ps_counts = counts[is_p].sum(axis=0) + prior
    ps = np.random.dirichlet(ps_counts)
    qs_counts = counts[~is_p].sum(axis=0) + prior
    qs = np.random.dirichlet(qs_counts)
    return ps, qs
# Run the Gibbs sampler: alternately draw the row assignments given
# the rates and the rates given the assignments, recording every draw.
num_cycles = 10000
data = np.loadtxt('data/gibbs_data.txt', dtype=int)
counts = data_to_counts(data)
# Take the sizes from the counts matrix: its columns are the outcome
# categories, whereas the columns of data are the observations per row.
num_rows, num_outcomes = counts.shape
ps_history = np.empty((num_cycles, num_outcomes))
qs_history = np.empty((num_cycles, num_outcomes))
assignments_history = np.empty((num_cycles, num_rows), dtype=bool)
# Draw sensible starting values.
ps = np.random.dirichlet(np.ones(num_outcomes) * counts.sum() / 2)
qs = np.random.dirichlet(np.ones(num_outcomes) * counts.sum() / 2)
# range (not the Python 2-only xrange) matches the plotting code below
# and works on both Python 2 and Python 3.
for i in range(num_cycles):
    assignments = draw_assignments(counts, ps, qs)
    ps, qs = draw_probabilities(counts, assignments)
    assignments_history[i] = assignments
    ps_history[i] = ps
    qs_history[i] = qs
The posterior distributions of $p$ and $q$ are distinct.
# Overlay step histograms of the first few sampled components of p
# and of q, one panel per rate vector.
fig, (ps_ax, qs_ax) = plt.subplots(1, 2, figsize=(16, 8))
components_to_plot = 2
panels = (
    ('p', ps_ax, ps_history),
    ('q', qs_ax, qs_history),
)
for label, ax, values in panels:
    for j in range(components_to_plot):
        samples = values[:, j]
        legend_entry = '{0}[{1}]'.format(label, j)
        ax.hist(samples, bins=100, histtype='step', label=legend_entry)
    panel_title = 'Posterior distribution of {0}'.format(label)
    ax.set_title(panel_title)
    ax.legend()
np.mean(ps_history, axis=0)
array([ 0.08127113, 0.11901405, 0.07965701, 0.12636537, 0.1126423 , 0.08212299, 0.07992786, 0.12643449, 0.07814592, 0.1144189 ])
np.mean(qs_history, axis=0)
array([ 0.12044314, 0.08564833, 0.11900612, 0.07219582, 0.08464854, 0.11952155, 0.11796448, 0.07773247, 0.11755809, 0.08528147])
The confidence with which each rat is assigned to a specific cluster is typically not very extreme.
np.mean(assignments_history, axis=0)
array([ 0.6788, 0.1579, 0.5252, 0.6496, 0.2798, 0.941 , 0.5888, 0.8602, 0.3345, 0.5007, 0.283 , 0.1966, 0.167 , 0.1791, 0.7433, 0.5654, 0.3227, 0.7931, 0.9389, 0.4958, 0.6565, 0.6585, 0.2895, 0.2827, 0.9296, 0.4638, 0.41 , 0.6715, 0.3177, 0.5033, 0.0402, 0.4392, 0.8509, 0.5411, 0.1124, 0.5685, 0.7973, 0.044 , 0.701 , 0.8114, 0.3254, 0.1219, 0.0801, 0.341 , 0.1794, 0.7378, 0.0947, 0.7881, 0.8396, 0.6476, 0.3309, 0.9732, 0.8316, 0.675 , 0.1221, 0.2886, 0.8213, 0.8286, 0.8933, 0.8071, 0.8384, 0.8436, 0.1665, 0.6942, 0.2755, 0.1449, 0.5556, 0.128 , 0.4752, 0.9772, 0.4797, 0.4987, 0.6151, 0.5728, 0.1835, 0.3351, 0.9003, 0.1884, 0.2859, 0.1614, 0.4551, 0.7596, 0.8465, 0.096 , 0.2926, 0.2474, 0.8846, 0.5061, 0.6547, 0.972 , 0.8626, 0.5974, 0.706 , 0.4849, 0.2527, 0.8669, 0.7154, 0.4398, 0.8956, 0.3745, 0.8056, 0.1812, 0.2249, 0.3652, 0.5825, 0.975 , 0.5186, 0.7462, 0.7146, 0.0915, 0.0904, 0.2406, 0.0413, 0.5509, 0.1764, 0.5654, 0.1894, 0.7267, 0.6827, 0.4981, 0.1951, 0.9397, 0.2748, 0.6457, 0.2609, 0.3583, 0.8588, 0.4878, 0.259 , 0.3958, 0.1427, 0.5615, 0.6467, 0.4116, 0.8825, 0.1686, 0.0913, 0.3924, 0.6526, 0.0838, 0.727 , 0.7457, 0.6346, 0.284 , 0.6704, 0.421 , 0.6573, 0.8522, 0.977 , 0.2122, 0.0503, 0.0783, 0.5973, 0.3028, 0.3086, 0.6471, 0.2883, 0.8533, 0.1933, 0.8348, 0.7843, 0.4074, 0.2195, 0.0935, 0.3412, 0.3329, 0.2708, 0.1963, 0.8936, 0.367 , 0.748 , 0.584 , 0.7091, 0.4569, 0.772 , 0.4286, 0.9184, 0.5236, 0.7187, 0.0768, 0.4768, 0.0733, 0.8951, 0.6366, 0.1807, 0.165 , 0.2705, 0.2776, 0.562 , 0.784 , 0.2033, 0.8154, 0.9615, 0.3354, 0.2508, 0.292 , 0.2094, 0.2105, 0.3745, 0.2952, 0.9714, 0.4935, 0.8032, 0.1662, 0.1722, 0.2864, 0.2574, 0.1902, 0.2518, 0.1206, 0.0947, 0.8249, 0.3245, 0.434 , 0.7226, 0.76 , 0.7153, 0.0501, 0.3481, 0.0752, 0.6722, 0.1748, 0.328 , 0.7592, 0.8853, 0.541 , 0.3625, 0.8656, 0.8794, 0.2854, 0.2795, 0.2836, 0.7596, 0.1542, 0.9599, 0.0381, 0.7561, 0.5797, 0.0436, 0.4352, 0.316 , 0.2891, 0.5824, 0.703 , 0.5086, 0.3818, 0.4893, 0.4466, 0.4463, 
0.3088, 0.9131, 0.7791, 0.2964, 0.7631, 0.068 , 0.4518, 0.0964, 0.5999, 0.1748, 0.8161, 0.3126, 0.277 , 0.8092, 0.6955, 0.0537, 0.5687, 0.0521, 0.7981, 0.3054, 0.4961, 0.173 , 0.5382, 0.2191, 0.5763, 0.3567, 0.3563, 0.6958, 0.4467, 0.9251, 0.498 , 0.8453, 0.3023, 0.653 , 0.1453, 0.5541, 0.4396, 0.6643, 0.353 , 0.2858, 0.6213, 0.4912, 0.3109, 0.4945, 0.8727, 0.5499, 0.0941, 0.9305, 0.5082, 0.7132, 0.6654, 0.6446, 0.6071, 0.7629, 0.8024, 0.2948, 0.1931, 0.2001, 0.0768, 0.1305, 0.2223, 0.042 , 0.1386, 0.3831, 0.1375, 0.1673, 0.1572, 0.5509, 0.9194, 0.3814, 0.7066, 0.7785, 0.8702, 0.5569, 0.8537, 0.4322, 0.261 , 0.519 , 0.7325, 0.0934, 0.5051, 0.8223, 0.9474, 0.437 , 0.5049, 0.1391, 0.2992, 0.3396, 0.7745, 0.6614, 0.1691, 0.8042, 0.9575, 0.2275, 0.164 , 0.5464, 0.8651, 0.1934, 0.0883, 0.6231, 0.2917, 0.3357, 0.488 , 0.1077, 0.1016, 0.2486, 0.5311, 0.6555, 0.332 , 0.9347, 0.2773, 0.5498, 0.4814, 0.3447, 0.7632, 0.8524, 0.3449, 0.1674, 0.8246, 0.0415, 0.4389, 0.288 , 0.504 , 0.6629, 0.711 , 0.7231, 0.389 , 0.3953, 0.6466, 0.1714, 0.3326, 0.9575, 0.9645, 0.7234, 0.9051, 0.8312, 0.7496, 0.6539, 0.6993, 0.9173, 0.7678, 0.1089, 0.1698, 0.1711, 0.8132, 0.1926, 0.6693, 0.9307, 0.0856, 0.6726, 0.7068, 0.4048, 0.8776, 0.0432, 0.8975, 0.478 , 0.102 , 0.2116, 0.5424, 0.515 , 0.8557, 0.9704, 0.5409, 0.6948, 0.4833, 0.4373, 0.0437, 0.2845, 0.7111, 0.4643, 0.1401, 0.487 , 0.5396, 0.3019, 0.0742, 0.2955, 0.6152, 0.6073, 0.7838, 0.6444, 0.8613, 0.8868, 0.7532, 0.2664, 0.684 , 0.1417, 0.2407, 0.0946, 0.4239, 0.8741, 0.6567, 0.1773, 0.1915, 0.8559, 0.2912, 0.4268, 0.7615, 0.5663, 0.9287, 0.303 , 0.6742, 0.7398, 0.5706, 0.6547, 0.6558, 0.0473, 0.8603, 0.6229, 0.8628, 0.7986, 0.8508, 0.2496, 0.7323, 0.2386, 0.5191, 0.2414, 0.4698, 0.2104, 0.3149, 0.5184, 0.591 , 0.1507, 0.1651, 0.564 , 0.1476, 0.7602, 0.347 , 0.4393, 0.6959, 0.5133, 0.1907, 0.7191, 0.8935, 0.3025, 0.6481, 0.5128, 0.3043, 0.5167, 0.3144, 0.0875, 0.4694, 0.7867, 0.2909, 0.2857, 0.5627, 0.2788, 0.4321, 0.3322, 0.2173, 0.3459, 
0.9633, 0.2508, 0.6768, 0.3016, 0.2416, 0.0732, 0.1456, 0.7526, 0.0746, 0.8267, 0.0783, 0.0752, 0.9333, 0.7929, 0.5607, 0.3273, 0.7082, 0.6012, 0.6495, 0.2377, 0.176 , 0.2489, 0.9706, 0.6731, 0.799 , 0.0766, 0.7156, 0.4479, 0.1001, 0.5516, 0.9282, 0.9265, 0.8518, 0.0961, 0.4229, 0.8321, 0.3301, 0.8027, 0.2059, 0.7417, 0.3317, 0.551 , 0.5218, 0.7047, 0.1713, 0.4983, 0.4488, 0.0793, 0.7325, 0.292 , 0.5042, 0.7984, 0.1595, 0.3862, 0.7166, 0.0768, 0.9249, 0.9186, 0.3249, 0.4323, 0.7914, 0.5963, 0.5466, 0.9273, 0.1385, 0.1972, 0.1821, 0.4542, 0.5034, 0.6622, 0.0856, 0.8811, 0.5623, 0.5859, 0.4778, 0.6231, 0.9333, 0.9454, 0.9558, 0.4739, 0.2922, 0.1353, 0.7165, 0.7439, 0.7015, 0.7162, 0.605 , 0.3285, 0.532 , 0.388 , 0.654 , 0.9397, 0.3821, 0.2594, 0.7833, 0.6565, 0.5267, 0.5021, 0.5681, 0.4948, 0.9072, 0.7778, 0.2792, 0.7181, 0.4948, 0.6208, 0.5151, 0.7889, 0.1666, 0.3172, 0.5466, 0.5912, 0.4107, 0.3554, 0.3813, 0.097 , 0.2802, 0.0438, 0.2644, 0.9431, 0.501 , 0.6056, 0.5695, 0.0703, 0.0907, 0.5428, 0.5047, 0.1558, 0.1644, 0.907 , 0.5787, 0.4722, 0.5123, 0.8361, 0.1841, 0.555 , 0.9116, 0.5371, 0.8723, 0.6673, 0.1616, 0.4387, 0.2846, 0.8308, 0.4795, 0.7531, 0.4991, 0.7343, 0.276 , 0.2844, 0.659 , 0.7908, 0.1663, 0.5306, 0.1647, 0.1484, 0.5464, 0.6543, 0.0936, 0.7097, 0.8434, 0.879 , 0.8256, 0.2742, 0.2469, 0.1334, 0.6809, 0.4382, 0.1694, 0.1085, 0.4366, 0.8705, 0.5082, 0.4379, 0.4301, 0.4958, 0.6431, 0.433 , 0.0773, 0.4447, 0.7799, 0.5011, 0.1753, 0.5586, 0.894 , 0.8673, 0.168 , 0.1621, 0.6921, 0.5685, 0.7592, 0.4267, 0.8865, 0.6496, 0.4271, 0.932 , 0.7177, 0.4326, 0.7007, 0.639 , 0.7071, 0.2957, 0.0875, 0.1613, 0.9478, 0.8731, 0.3775, 0.1421, 0.6717, 0.3369, 0.8785, 0.6199, 0.0831, 0.8307, 0.5561, 0.8119, 0.828 , 0.431 , 0.4089, 0.8255, 0.1991, 0.5569, 0.3459, 0.9031, 0.2188, 0.1886, 0.1493, 0.3268, 0.5101, 0.5 , 0.1646, 0.5278, 0.0914, 0.4361, 0.1754, 0.2796, 0.2103, 0.8649, 0.9015, 0.1552, 0.6976, 0.2461, 0.8159, 0.8389, 0.2818, 0.2901, 0.2884, 0.3569, 0.141 , 0.7158, 
0.656 , 0.1494, 0.1199, 0.9241, 0.1603, 0.9468, 0.6604, 0.0726, 0.9463, 0.7514, 0.7874, 0.5367, 0.4296, 0.9278, 0.3478, 0.3909, 0.3119, 0.8385, 0.8844, 0.3872, 0.733 , 0.4413, 0.0766, 0.841 , 0.4235, 0.7929, 0.4116, 0.8919, 0.5545, 0.8319, 0.7424, 0.2879, 0.862 , 0.9036, 0.9406, 0.6608, 0.2323, 0.9494, 0.0921, 0.4934, 0.28 , 0.8867, 0.2352, 0.5883, 0.2131, 0.2671, 0.5424, 0.7777, 0.3539, 0.4898, 0.4526, 0.4769, 0.9285, 0.7538, 0.6829, 0.2035, 0.4992, 0.6973, 0.9487, 0.7526, 0.7979, 0.6578, 0.802 , 0.3215, 0.2046, 0.8365, 0.6539, 0.9605, 0.0775, 0.8533, 0.1971, 0.1501, 0.3267, 0.6381, 0.1626, 0.9134, 0.2173, 0.753 , 0.5527, 0.7883, 0.8558, 0.2486, 0.8837, 0.2709, 0.2055, 0.3544, 0.1636, 0.204 , 0.1404, 0.7213, 0.6803, 0.3437, 0.5589, 0.6028, 0.0462, 0.5423, 0.8259, 0.6507, 0.1538, 0.7066, 0.2495, 0.617 , 0.3198, 0.1717, 0.8592, 0.4991, 0.8835, 0.7819, 0.6746, 0.7109, 0.5632, 0.6281, 0.3282, 0.7163, 0.8605, 0.0469, 0.7916, 0.5095, 0.3727, 0.0994, 0.9164, 0.4737, 0.7869, 0.7799, 0.4915, 0.4027, 0.7658, 0.0391, 0.9571, 0.6947, 0.5264, 0.489 , 0.5533, 0.0719, 0.7236, 0.8007, 0.5672, 0.6749, 0.9218, 0.3441, 0.0723, 0.6667, 0.6787, 0.6108, 0.0444, 0.8671, 0.5397, 0.5507, 0.1774, 0.3137, 0.3985, 0.5462, 0.5455, 0.1374, 0.9745, 0.3961, 0.8781, 0.9047, 0.435 , 0.3795, 0.7464, 0.1889, 0.5587, 0.1067, 0.0695, 0.7714, 0.3965, 0.1943, 0.8852, 0.0926, 0.2663, 0.7182, 0.1763, 0.2659, 0.3488, 0.2324, 0.5601, 0.2141, 0.1314, 0.3958, 0.5102, 0.7597, 0.9137, 0.9224, 0.91 , 0.1347, 0.2957, 0.749 , 0.6593, 0.8726, 0.1966, 0.215 , 0.1866, 0.8574, 0.6652, 0.79 , 0.2689, 0.2165, 0.6418, 0.7223, 0.4791, 0.7958, 0.8191, 0.7119, 0.2437, 0.5061, 0.1439, 0.6309, 0.3515, 0.1609, 0.3173, 0.1779, 0.1584, 0.0376, 0.2042, 0.315 , 0.3191, 0.7348, 0.1668, 0.4395, 0.7161, 0.8222, 0.1696, 0.5626, 0.4996, 0.092 , 0.6066, 0.6795, 0.75 , 0.3241, 0.2449, 0.9626, 0.1348, 0.9296, 0.8409, 0.3473, 0.5121, 0.7136, 0.9215, 0.3363, 0.0758, 0.0895, 0.768 , 0.6656, 0.4672, 0.293 , 0.6695, 0.1702, 0.8081, 0.4708, 
0.1632])